{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "orig_nbformat": 2,
    "colab": {
      "name": "IceApp_coco.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/ai-fast-track/icevision-gradio/blob/master/IceApp_coco.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iHu4AkVlnqGC",
        "colab_type": "text"
      },
      "source": [
        "# IceVision Deployment App: COCO Dataset\n",
        "This example uses Faster R-CNN weights trained on the [COCO dataset](https://airctic.github.io/icedata/coco/).\n",
        "\n",
        "About IceVision:\n",
        "\n",
        "- An object-detection framework that connects to different libraries and frameworks such as fastai, PyTorch Lightning, and PyTorch, with more to come.\n",
        "\n",
        "- Features a unified data API with out-of-the-box support for common annotation formats (COCO, VOC, etc.).\n",
        "\n",
        "- Provides flexible model implementations with pluggable backbones."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xd_v-Xlzn3fC",
        "colab_type": "text"
      },
      "source": [
        "## Installing packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xywzI5XVjI2S",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install icevision[inference]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2l2kdIv3x0dd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install icedata"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jGaD1jw0H9Tj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install gradio"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RNVETtzKn-cU",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2PDsjGDbHBZY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from icevision.all import *\n",
        "import icedata\n",
        "import PIL, requests\n",
        "import torch\n",
        "from torchvision import transforms\n",
        "import gradio as gr"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kAlNFd5JoAm8",
        "colab_type": "text"
      },
      "source": [
        "## Loading the trained model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oWqa2mF1Me7z",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class_map = icedata.coco.class_map()\n",
        "model = icedata.coco.trained_models.faster_rcnn_resnet50_fpn()"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dtUlW1X4oL73",
        "colab_type": "text"
      },
      "source": [
        "## Defining the `predict()` function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jIzikSZLHBZo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def predict(model, image, detection_threshold: float = 0.5):\n",
        "    # Normalize the image (albumentations' Normalize uses ImageNet mean/std by default)\n",
        "    tfms_ = tfms.A.Adapter([tfms.A.Normalize()])\n",
        "    # Whenever you have images in memory (numpy arrays) you can use `Dataset.from_images`\n",
        "    infer_ds = Dataset.from_images([image], tfms_)\n",
        "\n",
        "    # Collate the single sample into a batch and run Faster R-CNN inference on it\n",
        "    batch, samples = faster_rcnn.build_infer_batch(infer_ds)\n",
        "    preds = faster_rcnn.predict(\n",
        "        model=model,\n",
        "        batch=batch,\n",
        "        detection_threshold=detection_threshold,\n",
        "    )\n",
        "    # Return the transformed (normalized) image together with its predictions\n",
        "    return samples[0][\"img\"], preds[0]"
      ],
      "execution_count": 6,
      "outputs": []
    },
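    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`detection_threshold` decides which detections `predict()` keeps: detections whose confidence score falls below the threshold are dropped. The cell below is a minimal NumPy sketch of that filtering step, assuming the raw predictions arrive as parallel arrays of scores, labels, and boxes. The helper `filter_detections` and the sample values are hypothetical; this illustrates the idea and is not IceVision's internal code."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def filter_detections(scores, labels, boxes, detection_threshold=0.5):\n",
        "    # Hypothetical helper: keep detections whose score is at or above the threshold\n",
        "    scores = np.asarray(scores, dtype=float)\n",
        "    keep = scores >= detection_threshold  # boolean mask, one entry per detection\n",
        "    return scores[keep], np.asarray(labels)[keep], np.asarray(boxes, dtype=float)[keep]\n",
        "\n",
        "# Hypothetical raw output for three detections; boxes are (xmin, ymin, xmax, ymax)\n",
        "scores = [0.92, 0.41, 0.67]\n",
        "labels = [1, 3, 18]\n",
        "boxes = [[10, 20, 110, 220], [5, 5, 50, 50], [200, 40, 320, 260]]\n",
        "\n",
        "kept_scores, kept_labels, kept_boxes = filter_detections(scores, labels, boxes, detection_threshold=0.5)\n",
        "print(kept_labels.tolist())  # -> [1, 18]"
      ],
      "execution_count": null,
      "outputs": []
    },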
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zWqaACGOoUCM",
        "colab_type": "text"
      },
      "source": [
        "## Defining `show_preds()`, the function called by `gr.Interface(fn=show_preds, ...)`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qJHmeIhwoSS5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def show_preds(input_image, display_list, detection_threshold):\n",
        "    # The CheckboxGroup passes the list of checked options\n",
        "    display_label = (\"Label\" in display_list)\n",
        "    display_bbox = (\"BBox\" in display_list)\n",
        "\n",
        "    # Treat a slider value of 0 as the default threshold of 0.5\n",
        "    if detection_threshold == 0:\n",
        "        detection_threshold = 0.5\n",
        "\n",
        "    img, pred = predict(model=model, image=input_image, detection_threshold=detection_threshold)\n",
        "    # Draw the predictions on the denormalized image, then convert the array to a PIL image for Gradio\n",
        "    img = draw_pred(\n",
        "        img=img,\n",
        "        pred=pred,\n",
        "        class_map=class_map,\n",
        "        denormalize_fn=denormalize_imagenet,\n",
        "        display_label=display_label,\n",
        "        display_bbox=display_bbox,\n",
        "    )\n",
        "    img = PIL.Image.fromarray(img)\n",
        "    return img"
      ],
      "execution_count": 7,
      "outputs": []
    },
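    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`predict()` normalizes the image with `tfms.A.Normalize()`, and `show_preds()` passes `denormalize_imagenet` to `draw_pred` so the predictions are drawn on an image with its original colors. The cell below sketches that round trip in NumPy, assuming the ImageNet channel statistics that albumentations' `Normalize` uses by default (mean 0.485, 0.456, 0.406; std 0.229, 0.224, 0.225; pixel values divided by 255). `normalize_sketch` and `denormalize_sketch` are hypothetical stand-ins, not IceVision functions."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "# Assumed ImageNet statistics (the defaults of albumentations' Normalize)\n",
        "MEAN_RGB = np.array([0.485, 0.456, 0.406])\n",
        "STD_RGB = np.array([0.229, 0.224, 0.225])\n",
        "\n",
        "def normalize_sketch(img_uint8):\n",
        "    # Hypothetical helper: scale uint8 pixels to [0, 1], then standardize each channel with the ImageNet mean and std\n",
        "    return (img_uint8.astype(np.float32) / 255.0 - MEAN_RGB) / STD_RGB\n",
        "\n",
        "def denormalize_sketch(img_float):\n",
        "    # Hypothetical helper: invert normalize_sketch() and return a displayable uint8 image\n",
        "    img = (img_float * STD_RGB + MEAN_RGB) * 255.0\n",
        "    return np.clip(np.round(img), 0, 255).astype(np.uint8)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "sample = rng.integers(0, 256, size=(4, 4, 3)).astype(np.uint8)\n",
        "restored = denormalize_sketch(normalize_sketch(sample))\n",
        "print(np.array_equal(sample, restored))  # -> True"
      ],
      "execution_count": null,
      "outputs": []
    },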
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yRvdtz40oim7",
        "colab_type": "text"
      },
      "source": [
        "## Gradio User Interface"
      ]
    },
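    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`gr.Interface` calls `show_preds` with one argument per input component, in the order they are listed: the uploaded image, the list of checked `CheckboxGroup` options, and the slider value. The cell below pulls the option handling out of `show_preds` so it can be checked without a model or a running UI. `parse_display_options` is a hypothetical helper that mirrors the first lines of `show_preds`, including the fallback from a threshold of 0 to 0.5."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def parse_display_options(display_list, detection_threshold):\n",
        "    # Hypothetical helper mirroring the option handling at the top of show_preds\n",
        "    display_label = 'Label' in display_list\n",
        "    display_bbox = 'BBox' in display_list\n",
        "    # Presumably a threshold of 0 would keep every detection, so fall back to 0.5 as show_preds does\n",
        "    if detection_threshold == 0:\n",
        "        detection_threshold = 0.5\n",
        "    return display_label, display_bbox, detection_threshold\n",
        "\n",
        "print(parse_display_options(['BBox'], 0.7))  # -> (False, True, 0.7)\n",
        "print(parse_display_options(['Label', 'BBox'], 0))  # -> (True, True, 0.5)"
      ],
      "execution_count": null,
      "outputs": []
    },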
    {
      "cell_type": "code",
      "metadata": {
        "id": "_tGwcfYUHBZr",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "638b8519-60c4-4b99-f055-fa81f4f7958a"
      },
      "source": [
        "display_chkbox = gr.inputs.CheckboxGroup([\"Label\", \"BBox\"], label=\"Display\")\n",
        "detection_threshold_slider = gr.inputs.Slider(minimum=0, maximum=1, step=0.1, default=0.5, label=\"Detection Threshold\")\n",
        "\n",
        "outputs = gr.outputs.Image(type=\"pil\")\n",
        "\n",
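        "gr_interface = gr.Interface(\n",
        "    fn=show_preds,\n",
        "    inputs=[\"image\", display_chkbox, detection_threshold_slider],\n",
        "    outputs=outputs,\n",
        "    title=\"IceApp - COCO\",\n",
        ")\n",
        "# share=True creates a temporary public link; debug=True keeps the cell running so errors and logs stay visible\n",
        "gr_interface.launch(inline=False, share=True, debug=True)\n"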
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "source": [
        "## Enjoy!\n",
        "If you have any questions, feel free to [join us on Discord](https://discord.gg/JDBeZYK)."
      ],
      "cell_type": "markdown",
      "metadata": {
        "id": "NoW-55FRlkgv",
        "colab_type": "text"
      }
    }
  ]
}